import numpy as np
from scipy.linalg import svd, eigh


def A2logit(A: np.ndarray, tau: "float | None" = None, M: float = 4) -> np.ndarray:
    """
    Estimate the logit matrix by singular value thresholding (SVT).

    The adjacency matrix is denoised by zeroing every singular value that
    does not exceed ``tau``; the reconstruction is then clipped to a valid
    probability range and mapped through the logit function.

    Parameters
    ----------
    A : np.ndarray
        Binary adjacency matrix, shape (n, n).
    tau : float, optional
        Cutoff for SVT; singular values ``<= tau`` are set to zero.
        Default = sqrt(n * p_ave), where p_ave = sum(A) / (n * (n - 1)).
    M : float, optional
        Bound on the logits: probabilities are clipped to
        [sigmoid(-M), sigmoid(M)], so the output lies in [-M, M].
        Default = 4.

    Returns
    -------
    logit : np.ndarray
        Estimated logit matrix, same shape as ``A``.
    """
    A = np.asarray(A, dtype=float)

    # The density-based default is only evaluated when it is needed, so an
    # explicit ``tau`` never triggers the division by n * (n - 1).
    if tau is None:
        n = A.shape[0]
        p_ave = np.sum(A) / (n * (n - 1))
        tau = np.sqrt(n * p_ave)

    # Hard-threshold the singular values.
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    s = np.where(s > tau, s, 0.0)

    # Scaling the columns of U by s equals U @ diag(s) without building
    # the dense diagonal matrix.
    low_rank = (U * s) @ Vt

    # Clip to [sigmoid(-M), sigmoid(M)] so the logit below stays finite.
    lower = 1 / (1 + np.exp(M))
    upper = 1 / (1 + np.exp(-M))
    probs = np.clip(low_rank, lower, upper)

    # logit(p) = log(p / (1 - p)) = -log(1 / p - 1)
    return -np.log(1 / probs - 1)


def center(X: np.ndarray) -> np.ndarray:
    """
    Double-center a square matrix.

    The column means are subtracted first, then the row means of the
    result, so every row and every column of the output sums to zero.

    Parameters
    ----------
    X : np.ndarray
        Square matrix, shape (n, n).

    Returns
    -------
    np.ndarray
        Double-centered copy of ``X`` as a float array, shape (n, n).

    Raises
    ------
    ValueError
        If ``X`` is not a square 2-D matrix.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2 or X.shape[0] != X.shape[1]:
        raise ValueError(
            f"center expects a square 2-D matrix, got shape {X.shape}"
        )

    # Remove the column means, then the row means of the result.
    X = X - X.mean(axis=0, keepdims=True)
    X = X - X.mean(axis=1, keepdims=True)

    return X


def init_SVT(A: np.ndarray, k: int) -> "tuple[np.ndarray, np.ndarray]":
    """
    Initialization by Singular Value Thresholding (SVT).
    Algorithm 3 in Ma and Ma, 2017, without covariate matrix X.

    Parameters
    ----------
    A : np.ndarray
        Binary adjacency matrix, shape (n, n).
    k : int
        Dimension of the latent vector.

    Returns
    -------
    Z0 : np.ndarray
        Initial estimate of the latent vector matrix, shape (n, k).
    alpha0 : np.ndarray
        Initial estimate of the degree heterogeneity, shape (n, ).
    """
    # Coerce up front so array-likes work here as they do in A2logit.
    A = np.asarray(A, dtype=float)
    n = A.shape[0]

    # === estimate logit matrix ===
    logit = A2logit(A)
    logit = (logit + logit.T) / 2  # symmetrize

    # === solving least squares ===
    # Fit logit ~ alpha 1^T + 1 alpha^T: each row sum equals
    # n * alpha_i + sum(alpha), and sum(alpha) is half the mean row sum.
    row_sums = logit.sum(axis=1)
    alpha0 = (row_sums - 0.5 * row_sums.mean()) / n  # shape (n,)

    # === finding Z0 ===
    # Remove the degree effects, then double-center the residual.
    R = logit - alpha0[:, None] - alpha0[None, :]
    G0 = center(R)

    # Symmetrize against round-off before the symmetric eigensolver.
    G0_sym = (G0 + G0.T) / 2
    eigvals, eigvecs = np.linalg.eigh(G0_sym)

    # eigh returns eigenvalues in ascending order; reverse to get the
    # largest first and keep the leading k eigenpairs.
    top_vals = eigvals[::-1][:k]
    top_vecs = eigvecs[:, ::-1][:, :k]

    # Clamp negative eigenvalues (including round-off negatives of the zero
    # eigenvalue introduced by centering) so the square root is never NaN.
    Z0 = top_vecs * np.sqrt(np.maximum(top_vals, 0.0))

    return Z0, alpha0